#!/usr/bin/env python

import os
import time
import datetime
import argparse
import torch
import torch.nn as nn
import numpy as np
from PIL import Image
from diffusers import StableDiffusionPipeline, UNet2DConditionModel
from transformers import CLIPTokenizer
import importlib.util

# Load the custom sd module directly from a file path (it does not need to be on sys.path)
def load_sd_module(sd_path):
    """Load the Python source file at ``sd_path`` as a module named ``sd``.

    The file is executed via importlib straight from its path; the resulting
    module is not registered in ``sys.modules``.

    Args:
        sd_path: Path to a Python source file (e.g. ``diffusers_rewrite/sd.py``).

    Returns:
        The executed module object.

    Raises:
        ImportError: If importlib cannot build a spec/loader for ``sd_path``
            (for example, the file does not have a Python source suffix).
        Any error raised while reading or executing the file propagates as-is.
    """
    spec = importlib.util.spec_from_file_location("sd", sd_path)
    # spec_from_file_location returns None for paths whose suffix it does not
    # recognise; without this check the failure surfaces as an AttributeError.
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot load module from {sd_path!r}: not a recognized Python source file")
    sd_module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(sd_module)
    return sd_module

class T2ILogQuantizer(nn.Module):
    """Power-of-two (log2) fake quantizer with a single tensor-wide scale.

    Each element's magnitude is divided by the tensor's largest magnitude and
    mapped to an integer exponent ``q = round(-log2(|x| / scale))`` clamped to
    ``[0, 2**bits - 1]``; it is then dequantized to ``sign(x) * 2**-q * scale``.
    Signs are carried separately, so signed tensors (e.g. K/V activations) are
    handled as well as non-negative ones (e.g. softmax weights).

    Args:
        bits: Number of bits for the exponent index; ``qmax = 2**bits - 1``.
    """

    def __init__(self, bits=8):
        super().__init__()
        self.bits = bits
        self.qmax = 2**bits - 1

    def forward(self, x):
        """Fake-quantize ``x`` to signed powers of two; return it in ``x``'s dtype.

        Math is done in float32 and cast back to the input dtype. Empty and
        all-zero tensors are returned unchanged. Zeros in a non-zero tensor are
        not kept exact: because of the 1e-10 epsilon they dequantize to the
        small positive value ``2**-min(33, qmax) * scale``.
        """
        if x.numel() == 0:
            return x

        orig_dtype = x.dtype
        x = x.to(torch.float32)

        # Normalize by the largest *magnitude* (not the signed max) so every
        # normalized value lies in [0, 1] and log2 never sees a negative input,
        # which previously produced NaNs for signed tensors.
        magnitude = x.abs()
        scale = magnitude.max()
        if scale == 0:
            return x.to(orig_dtype)

        # +1 for positives and zeros, -1 for negatives; reapplied after dequantization.
        sign = 1.0 - 2.0 * (x < 0).to(torch.float32)

        # Exponent index q = round(-log2(|x| / scale)); the epsilon avoids log2(0).
        x_log = torch.log2(magnitude / scale + 1e-10) * -1
        x_quantized = torch.clamp(torch.round(x_log), 0, self.qmax)

        # Dequantize
        x_dequant = sign * torch.exp2(x_quantized * -1) * scale

        return x_dequant.to(orig_dtype)

class T2IUniformQuantizer(nn.Module):
    """Asymmetric min-max uniform fake quantizer with one tensor-wide range.

    Values are snapped to ``2**bits`` evenly spaced levels spanning
    ``[x.min(), x.max()]`` and then mapped back to real values.

    Args:
        bits: Quantization bit width; the top level index is ``2**bits - 1``.
    """

    def __init__(self, bits=8):
        super().__init__()
        self.bits = bits
        self.qmax = 2**bits - 1

    def forward(self, x):
        """Fake-quantize ``x`` uniformly and return it in ``x``'s original dtype.

        Empty tensors come back untouched; constant tensors (min == max) are
        returned unchanged apart from the float32 round trip.
        """
        if not x.numel():
            return x

        source_dtype = x.dtype
        values = x.to(torch.float32)

        lo = values.min()
        hi = values.max()
        if lo == hi:
            return values.to(source_dtype)

        # Width of one quantization step; lo acts as the zero point.
        step = (hi - lo) / self.qmax
        levels = torch.round((values - lo) / step).clamp(0, self.qmax)
        return (levels * step + lo).to(source_dtype)

def _token_quant_mask(seq_len, preserve_token_indices, device):
    """Return a bool mask of length ``seq_len``: True = quantize, False = preserve.

    Indices in ``preserve_token_indices`` that are ``>= seq_len`` are ignored.
    """
    mask = torch.ones(seq_len, dtype=torch.bool, device=device)
    for idx in preserve_token_indices:
        if idx < seq_len:
            mask[idx] = False
    return mask


def _quantize_masked_tokens(x, mask, quantizer, token_axis):
    """Fake-quantize the positions of ``x`` selected by ``mask`` along ``token_axis``.

    ``token_axis`` is -2 for (..., tokens, dim) tensors such as K/V, or -1 for
    (..., tokens) tensors such as attention weights. All selected positions are
    passed to ``quantizer`` together, so they share one scale; unselected
    positions are copied through unchanged. The work is done in float32 and the
    result is cast back to ``x``'s dtype. Returns ``x`` itself when nothing is selected.
    """
    index = (Ellipsis, mask) if token_axis == -1 else (Ellipsis, mask, slice(None))
    x_fp32 = x.to(torch.float32)
    selected = x_fp32[index]
    if selected.numel() == 0:
        return x
    out = x_fp32.clone()  # clone: x_fp32 may alias x when x is already float32
    out[index] = quantizer(selected)
    return out.to(x.dtype)


def patch_attention_forward_tokenwise(pipe, sd_module, quant_bits=8, token_indices=None,
                                      use_log_quant=True, quantize_kv=True, quantize_w=True,
                                      kv_quant_bits=None, debug=False):
    """
    Patch every ``Attention`` module in ``pipe.unet`` with a token-wise fake-quantizing forward.

    The replacement forward recomputes attention itself: to_q/to_k/to_v, a
    softmax over ``q @ k^T * self.scale``, then ``self.to_out``. Only for
    cross-attention calls (``encoder_hidden_states`` given), positions along the
    text-token axis that are NOT in ``token_indices`` are fake-quantized in K/V
    and/or in the attention weights; listed positions stay at full precision.
    Self-attention is never quantized. Calling this again replaces any earlier
    patch (the very first original forward is saved in ``_original_forward``).

    NOTE(review): ``attention_mask`` and extra keyword arguments are accepted but
    ignored, and anything else the original forward/processor may do (norms,
    residuals, masking) is not reproduced -- confirm that is acceptable for the UNet.

    Args:
        pipe: StableDiffusion pipeline; its ``unet`` is patched in place
        sd_module: The loaded sd.py module (accepted for interface compatibility; unused here)
        quant_bits: Number of bits for quantization (for attention weights)
        token_indices: Token indices to preserve (not quantize). ``None`` or empty
            disables quantization; indices >= the sequence length are ignored.
        use_log_quant: Use log quantization if True, uniform if False
        quantize_kv: Whether to quantize K and V tensors
        quantize_w: Whether to quantize attention weights
        kv_quant_bits: Separate bit setting for K/V quantization (defaults to quant_bits)
        debug: Print tensor statistics for cross-attention calls
    """
    kv_bits = kv_quant_bits if kv_quant_bits is not None else quant_bits

    quantizer_cls = T2ILogQuantizer if use_log_quant else T2IUniformQuantizer
    quantizer_k = quantizer_cls(bits=kv_bits)
    quantizer_v = quantizer_cls(bits=kv_bits)
    quantizer_w = quantizer_cls(bits=quant_bits)

    def create_attention_forward(original_forward, preserve_token_indices):
        # original_forward is kept for call-site compatibility; the patched
        # forward below reimplements attention and never calls it.
        def new_forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, **kwargs):
            is_cross_attn = encoder_hidden_states is not None
            context = encoder_hidden_states if is_cross_attn else hidden_states

            q = self.to_q(hidden_states)
            k = self.to_k(context)
            v = self.to_v(context)

            b, t, c = q.size()
            head_dim = c // self.heads

            # (batch, seq, heads * head_dim) -> (batch, heads, seq, head_dim)
            q = q.view(b, t, self.heads, head_dim).transpose(1, 2)
            k = k.view(b, k.size(1), self.heads, head_dim).transpose(1, 2)
            v = v.view(b, v.size(1), self.heads, head_dim).transpose(1, 2)

            # Token-wise quantization only applies to cross-attention with a
            # non-empty preserve list.
            token_wise = (
                is_cross_attn
                and preserve_token_indices is not None
                and len(preserve_token_indices) > 0
            )

            if debug and is_cross_attn:
                print(f"Before quant - K stats: min={k.min():.4f}, max={k.max():.4f}, mean={k.mean():.4f}")
                print(f"Before quant - V stats: min={v.min():.4f}, max={v.max():.4f}, mean={v.mean():.4f}")

            if quantize_kv and token_wise:
                # K and V come from the same context, so they share one token mask.
                kv_mask = _token_quant_mask(k.size(-2), preserve_token_indices, k.device)
                if kv_mask.any():
                    k = _quantize_masked_tokens(k, kv_mask, quantizer_k, token_axis=-2)
                    v = _quantize_masked_tokens(v, kv_mask, quantizer_v, token_axis=-2)
                    if debug:
                        print(f"After KV quant - K stats: min={k.min():.4f}, max={k.max():.4f}, mean={k.mean():.4f}")
                        print(f"After KV quant - V stats: min={v.min():.4f}, max={v.max():.4f}, mean={v.mean():.4f}")

            scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
            attn_weights = torch.softmax(scores, dim=-1)

            if debug and is_cross_attn:
                print(f"Attention weights stats: min={attn_weights.min():.4f}, max={attn_weights.max():.4f}")

            if quantize_w and token_wise:
                w_mask = _token_quant_mask(attn_weights.size(-1), preserve_token_indices, attn_weights.device)
                if w_mask.any():
                    attn_weights = _quantize_masked_tokens(attn_weights, w_mask, quantizer_w, token_axis=-1)
                    if debug:
                        print(f"After weight quant - stats: min={attn_weights.min():.4f}, max={attn_weights.max():.4f}")

            attn_output = torch.matmul(attn_weights, v)
            attn_output = attn_output.transpose(1, 2).contiguous().view(b, t, c)

            # Apply output projection
            for layer in self.to_out:
                attn_output = layer(attn_output)

            if debug and is_cross_attn:
                print(f"Output stats: min={attn_output.min():.4f}, max={attn_output.max():.4f}, mean={attn_output.mean():.4f}")
                print("-" * 50)

            return attn_output

        return new_forward

    # Patch all attention layers in the UNet (matched by class name).
    for module in pipe.unet.modules():
        if module.__class__.__name__ != "Attention":
            continue
        # Save the true original only once so repeated patching never wraps a patch.
        if not hasattr(module, '_original_forward'):
            module._original_forward = module.forward
        module.forward = create_attention_forward(module._original_forward, token_indices).__get__(module, module.__class__)

    
def unpatch_attention(pipe):
    """Restore the saved forward of every patched ``Attention`` module in ``pipe.unet``.

    A submodule qualifies when its class is named ``Attention`` and it carries a
    ``_original_forward`` attribute; its ``forward`` is reassigned to that saved
    callable. Unpatched modules are skipped, and ``_original_forward`` is kept.
    """
    for module in pipe.unet.modules():
        is_attention = module.__class__.__name__ == "Attention"
        if not is_attention or not hasattr(module, '_original_forward'):
            continue
        module.forward = module._original_forward

def _generate_image(pipe, prompt, seed, num_inference_steps, guidance_scale):
    """Run ``pipe`` once with a freshly seeded CUDA generator and return the first image."""
    generator = torch.Generator(device="cuda").manual_seed(seed)
    return pipe(
        prompt,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        generator=generator
    ).images[0]


def _complement_indices(quantize_indices, seq_len):
    """Return every index in ``range(seq_len)`` that is NOT in ``quantize_indices``.

    ``patch_attention_forward_tokenwise`` takes the indices to *preserve*, so this
    turns "quantize exactly these positions" into the form it expects.
    """
    quantize = set(quantize_indices)
    return [idx for idx in range(seq_len) if idx not in quantize]


def _filename_safe_token(token):
    """Make a tokenizer token usable as part of a filename.

    Strips a ``</w>`` suffix (assumes CLIP-style BPE tokens, which mark word ends
    that way) and replaces characters invalid in filenames on common platforms with ``_``.
    """
    cleaned = token.replace("</w>", "")
    for ch in '/\\<>:"|?*':
        cleaned = cleaned.replace(ch, "_")
    return cleaned or "_"


def generate_tokenwise_quantized_images(
    pipe, 
    prompt, 
    sd_module,
    output_dir,
    quant_bits=8,
    use_log_quant=True,
    seed=42,
    num_inference_steps=25,
    guidance_scale=7.5
):
    """
    Generate a baseline image plus images with selected prompt tokens quantized.

    Writes into ``output_dir``:
      * ``base_no_quant.png`` -- attention left unpatched;
      * ``token_{i:02d}_{token}_{quant}_{bits}bit.png`` -- only prompt token ``i``
        (sequence position ``i + 1``) is quantized in cross-attention;
      * ``all_tokens_{quant}_{bits}bit.png`` -- every prompt token is quantized,
        while the start token and trailing end/padding positions are preserved.
    Every image uses the same ``seed``. The UNet's attention is always restored
    (unpatched) on return, including when generation raises.

    Returns:
        List of dicts with keys ``'token'``, ``'index'`` and ``'filepath'``, one per
        per-token image.
    """
    os.makedirs(output_dir, exist_ok=True)
    quant_type = "log" if use_log_quant else "uniform"

    # Tokenize the prompt to understand token boundaries
    tokenizer = pipe.tokenizer
    tokens = tokenizer.tokenize(prompt)
    token_ids = tokenizer.encode(prompt, add_special_tokens=True)
    # Length of the text sequence seen by cross-attention. Assumes the pipeline
    # pads/truncates prompts to tokenizer.model_max_length -- TODO confirm for
    # pipelines other than StableDiffusionPipeline.
    seq_len = tokenizer.model_max_length

    print(f"\nPrompt: '{prompt}'")
    print(f"Tokens: {tokens}")
    print(f"Token IDs: {token_ids}")
    print(f"Total tokens (including special): {len(token_ids)}")

    # Position 0 is the start token and the last slot is kept for the end token,
    # so at most seq_len - 2 prompt tokens survive truncation.
    text_tokens = list(zip(tokens, token_ids[1:-1]))[: max(seq_len - 2, 0)]

    results = []
    try:
        print("\nGenerating base image (no quantization)...")
        unpatch_attention(pipe)
        base_image = _generate_image(pipe, prompt, seed, num_inference_steps, guidance_scale)
        base_path = os.path.join(output_dir, "base_no_quant.png")
        base_image.save(base_path)
        print(f"Saved: {base_path}")

        for i, (token, token_id) in enumerate(text_tokens):
            actual_idx = i + 1  # Account for start token

            print(f"\nQuantizing token {actual_idx}: '{token}' (ID: {token_id})")

            # The patch keeps token_indices at full precision, so preserve every
            # position except the one being studied.
            patch_attention_forward_tokenwise(
                pipe,
                sd_module,
                quant_bits=quant_bits,
                token_indices=_complement_indices([actual_idx], seq_len),
                use_log_quant=use_log_quant
            )

            image = _generate_image(pipe, prompt, seed, num_inference_steps, guidance_scale)
            filename = f"token_{i:02d}_{_filename_safe_token(token)}_{quant_type}_{quant_bits}bit.png"
            filepath = os.path.join(output_dir, filename)
            image.save(filepath)
            print(f"Saved: {filepath}")

            results.append({
                'token': token,
                'index': actual_idx,
                'filepath': filepath
            })

        print("\nQuantizing all tokens except special tokens...")
        all_token_indices = list(range(1, len(text_tokens) + 1))
        patch_attention_forward_tokenwise(
            pipe,
            sd_module,
            quant_bits=quant_bits,
            token_indices=_complement_indices(all_token_indices, seq_len),
            use_log_quant=use_log_quant
        )

        all_quant_image = _generate_image(pipe, prompt, seed, num_inference_steps, guidance_scale)
        all_quant_path = os.path.join(output_dir, f"all_tokens_{quant_type}_{quant_bits}bit.png")
        all_quant_image.save(all_quant_path)
        print(f"Saved: {all_quant_path}")
    finally:
        # Restore original attention even if a generation step failed.
        unpatch_attention(pipe)

    return results

def _grid_label(filename):
    """Return the caption drawn on a grid tile for an output image filename."""
    if filename.startswith("base_"):
        return "Base (No Quant)"
    if filename.startswith("all_tokens"):
        return "All Tokens"
    parts = filename.split('_')
    if filename.startswith("token_") and len(parts) > 4:
        # token_{i:02d}_{token}_{quant_type}_{bits}bit.png: the token sits between
        # the index and the last two fields and may itself contain '_'.
        return f"Token: {'_'.join(parts[2:-2])}"
    if len(parts) > 2:
        return f"Token: {parts[2]}"
    return filename.split('.')[0]


def create_comparison_grid(output_dir, prompt):
    """Tile every PNG in ``output_dir`` into one labelled comparison image.

    Images are read in sorted filename order (a previous ``_comparison_grid.png``
    is skipped so it is never tiled into itself), laid out at most 4 per row
    using the first image's size for every cell, captioned from their filenames,
    and written to ``output_dir/_comparison_grid.png`` under a 30 px title band.
    Prints a message and returns None when there are no images.
    """
    from PIL import Image, ImageDraw, ImageFont
    import glob

    grid_name = "_comparison_grid.png"

    image_files = sorted(
        path for path in glob.glob(os.path.join(output_dir, "*.png"))
        if os.path.basename(path) != grid_name
    )
    if not image_files:
        print("No images found to create grid")
        return

    images = []
    labels = []
    for filepath in image_files:
        # Copy the pixels and close the file right away instead of keeping one
        # open handle per image.
        with Image.open(filepath) as img:
            images.append(img.copy())
        labels.append(_grid_label(os.path.basename(filepath)))

    def load_font(path, size):
        try:
            return ImageFont.truetype(path, size)
        except (OSError, ImportError):
            # Font file missing or FreeType support unavailable: use PIL's built-in font.
            return ImageFont.load_default()

    img_width, img_height = images[0].size
    cols = min(4, len(images))
    rows = (len(images) + cols - 1) // cols
    title_height = 30  # Extra space for title

    grid = Image.new('RGB', (cols * img_width, rows * img_height + title_height), 'white')
    draw = ImageDraw.Draw(grid)

    title_font = load_font("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 20)
    label_font = load_font("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 12)

    # Only add an ellipsis when the prompt was actually cut.
    shown_prompt = prompt if len(prompt) <= 50 else prompt[:50] + "..."
    draw.text((10, 5), f"Token-wise Quantization: {shown_prompt}", fill='black', font=title_font)

    for idx, (img, label) in enumerate(zip(images, labels)):
        row, col = divmod(idx, cols)
        x = col * img_width
        y = row * img_height + title_height
        grid.paste(img, (x, y))

        # White copy offset by one pixel under the black label, for legibility.
        draw.text((x + 5, y + 5), label, fill='white', font=label_font)
        draw.text((x + 4, y + 4), label, fill='black', font=label_font)

    grid_path = os.path.join(output_dir, grid_name)
    grid.save(grid_path)
    print(f"\nComparison grid saved: {grid_path}")

def main():
    """Command-line entry point.

    Parses arguments, loads the sd module and a Stable Diffusion pipeline whose
    UNet comes from ``<model_path>/<unet_checkpoint>/unet``, runs the token-wise
    quantization sweep, builds the comparison grid, and writes ``metadata.txt``
    into a timestamped run directory under ``--output_dir``.
    """
    parser = argparse.ArgumentParser(description="Token-wise quantization for Stable Diffusion")
    parser.add_argument("--model_path", type=str, required=True, help="Path to Stable Diffusion model")
    parser.add_argument("--sd_module_path", type=str, default="diffusers_rewrite/sd.py", help="Path to sd.py module")
    parser.add_argument("--prompt", type=str, required=True, help="Text prompt for generation")
    parser.add_argument("--output_dir", type=str, default="tokenwise_output", help="Output directory")
    parser.add_argument("--quant_bits", type=int, default=4, help="Number of quantization bits")
    parser.add_argument("--quantizer", type=str, choices=["log", "uniform"], default="log", help="Quantization method")
    parser.add_argument("--seed", type=int, default=42, help="Random seed")
    parser.add_argument("--steps", type=int, default=25, help="Number of inference steps")
    parser.add_argument("--guidance_scale", type=float, default=7.5, help="Guidance scale")
    parser.add_argument("--unet_checkpoint", type=str, default="checkpoint-700",
                        help="Checkpoint directory under --model_path that contains the fine-tuned 'unet' subfolder")

    args = parser.parse_args()

    # Create timestamped output directory
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    output_dir = os.path.join(args.output_dir, f"run_{timestamp}")
    os.makedirs(output_dir, exist_ok=True)

    print(f"Output directory: {output_dir}")

    print(f"Loading SD module from: {args.sd_module_path}")
    sd_module = load_sd_module(args.sd_module_path)

    print(f"Loading Stable Diffusion pipeline from: {args.model_path}")

    unet = UNet2DConditionModel.from_pretrained(
        os.path.join(args.model_path, args.unet_checkpoint),
        subfolder="unet"
    ).to("cuda")

    # Hand the fine-tuned UNet to from_pretrained so the base UNet is never
    # loaded only to be discarded, and skip loading the safety checker.
    pipe = StableDiffusionPipeline.from_pretrained(
        args.model_path,
        unet=unet,
        safety_checker=None,
        requires_safety_checker=False,
    ).to("cuda")

    generate_tokenwise_quantized_images(
        pipe=pipe,
        prompt=args.prompt,
        sd_module=sd_module,
        output_dir=output_dir,
        quant_bits=args.quant_bits,
        use_log_quant=(args.quantizer == "log"),
        seed=args.seed,
        num_inference_steps=args.steps,
        guidance_scale=args.guidance_scale
    )

    create_comparison_grid(output_dir, args.prompt)

    # UTF-8 so non-ASCII prompts can be written regardless of platform locale.
    metadata_path = os.path.join(output_dir, "metadata.txt")
    with open(metadata_path, 'w', encoding="utf-8") as f:
        f.write(f"Prompt: {args.prompt}\n")
        f.write(f"Quantization: {args.quantizer} {args.quant_bits}-bit\n")
        f.write(f"Seed: {args.seed}\n")
        f.write(f"Steps: {args.steps}\n")
        f.write(f"Guidance Scale: {args.guidance_scale}\n")
        f.write(f"Model: {args.model_path}\n")
        f.write(f"UNet checkpoint: {args.unet_checkpoint}\n")
        f.write(f"Timestamp: {timestamp}\n")

    print(f"\nAll outputs saved to: {output_dir}")
    print("Done!")

if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()